# Some of the test data and the evaluation method are adapted from https://github.com/OSU-NLP-Group/HippoRAG/tree/main/data . Many thanks to the authors.

################################################################################
# Note: This evaluation script consumes a large number of LLM tokens. Make sure you have a sufficient token budget before running it.
################################################################################
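# Example invocation (assuming this file is run as a script, e.g. saved as evaluate.py;
# flags and defaults mirror the argparse setup in main_eval() below):
#   python evaluate.py --dataset 2wikimultihopqa --config_yaml ./eval_config.yaml --pre_num 30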
import argparse
import ast
import json
import logging
import os
import time
import warnings
from typing import List, Tuple

import pandas as pd

from deepsearcher.configuration import Configuration, init_config
from deepsearcher.offline_loading import load_from_local_files
from deepsearcher.online_query import naive_retrieve, retrieve

httpx_logger = logging.getLogger("httpx")  # silence httpx INFO logs emitted by the OpenAI client
httpx_logger.setLevel(logging.WARNING)


warnings.simplefilter(action="ignore", category=FutureWarning)  # suppress FutureWarning output


current_dir = os.path.dirname(os.path.abspath(__file__))

k_list = [2, 5]
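# Recall is reported at these cutoffs: Recall@k = |gold titles ∩ top-k retrieved titles| / |gold titles|.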


def _deepsearch_retrieve_titles(
    question: str, retry_num: int = 4, base_wait_time: int = 4
) -> Tuple[List[str], bool]:
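    """Retrieve passage titles for `question` with DeepSearcher, retrying with exponential backoff; returns (titles, fail_flag)."""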
    retrieved_results = []
    for i in range(retry_num):
        try:
            retrieved_results, _, _ = retrieve(question)
            break
        except Exception:
            wait_time = base_wait_time * (2**i)
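            # e.g. with the defaults (retry_num=4, base_wait_time=4), waits are 4s, 8s, 16s, 32s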
            print(f"Parse LLM's output failed, retry again after {wait_time} seconds...")
            time.sleep(wait_time)
    if retrieved_results:
        retrieved_titles = [
            retrieved_result.metadata["title"] for retrieved_result in retrieved_results
        ]
        fail = False
    else:
        print("Pipeline error, no retrieved results.")
        retrieved_titles = []
        fail = True
    return retrieved_titles, fail


def _naive_retrieve_titles(question: str) -> List[str]:
    retrieved_results = naive_retrieve(question)
    retrieved_titles = [
        retrieved_result.metadata["title"] for retrieved_result in retrieved_results
    ]
    return retrieved_titles


def _calcu_recall(sample, retrieved_titles, dataset) -> dict:
    if dataset in ["2wikimultihopqa"]:
        gold_passages = [item for item in sample["supporting_facts"]]
        gold_items = set([item[0] for item in gold_passages])
        retrieved_items = retrieved_titles
    else:
        raise NotImplementedError

    recall = dict()
    for k in k_list:
        recall[k] = round(
            sum(1 for t in gold_items if t in retrieved_items[:k]) / len(gold_items), 4
        )
    return recall
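# A worked example of the recall@k computation above (hypothetical titles):
#   sample = {"supporting_facts": [["Title A", 0], ["Title B", 2]]}
#   retrieved = ["Title A", "Title X", "Title B", "Title Y", "Title Z"]
#   _calcu_recall(sample, retrieved, "2wikimultihopqa")  # -> {2: 0.5, 5: 1.0}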


def _print_recall_line(recall: dict, pre_str="", post_str="\n"):
    print(pre_str, end="")
    for k in k_list:
        print(f"R@{k}: {recall[k]:.3f} ", end="")
    print(post_str, end="")


def evaluate(
    dataset: str,
    output_root: str,
    pre_num: int = 10,
    skip_load=False,
    flag: str = "result",
):
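    """Build (or reuse) the corpus collection, run DeepSearcher and naive RAG retrieval on each question,
    and append per-sample recall@k results to details.csv under output_root/flag."""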
    corpus_file = os.path.join(current_dir, f"../examples/data/{dataset}_corpus.json")
    if not skip_load:
        # The corpus is already pre-chunked, so use a very large chunk size to avoid splitting passages again.
        load_from_local_files(
            corpus_file, force_new_collection=True, chunk_size=999999, chunk_overlap=0
        )

    eval_output_subdir = os.path.join(output_root, flag)
    os.makedirs(eval_output_subdir, exist_ok=True)
    csv_file_path = os.path.join(eval_output_subdir, "details.csv")

    data_with_gt_file_path = os.path.join(current_dir, f"../examples/data/{dataset}.json")
    with open(data_with_gt_file_path, "r") as f:
        data_with_gt = json.load(f)

    if not pre_num:
        pre_num = len(data_with_gt)

    pipeline_error_num = 0
    end_ind = min(pre_num, len(data_with_gt))

    start_ind = 0
    existing_df = pd.DataFrame()
    if os.path.exists(csv_file_path):
        existing_df = pd.read_csv(csv_file_path)
        start_ind = len(existing_df)
        print(f"Loading results from {csv_file_path}, start_index = {start_ind}")

    for sample_idx, sample in enumerate(data_with_gt[start_ind:end_ind]):
        global_idx = sample_idx + start_ind
        question = sample["question"]

        retrieved_titles, fail = _deepsearch_retrieve_titles(question)
        retrieved_titles_naive = _naive_retrieve_titles(question)

        if fail:
            pipeline_error_num += 1
            print(
                f"Pipeline error, no retrieved results. Current pipeline_error_num = {pipeline_error_num}"
            )

        print(f"idx: {global_idx}: ")
        recall = _calcu_recall(sample, retrieved_titles, dataset)
        recall_naive = _calcu_recall(sample, retrieved_titles_naive, dataset)
        current_result = [
            {
                "idx": global_idx,
                "question": question,
                "recall": recall,
                "recall_naive": recall_naive,
                "gold_titles": [item[0] for item in sample["supporting_facts"]],
                "retrieved_titles": retrieved_titles,
                "retrieved_titles_naive": retrieved_titles_naive,
            }
        ]
        current_df = pd.DataFrame(current_result)
        existing_df = pd.concat([existing_df, current_df], ignore_index=True)
        existing_df.to_csv(csv_file_path, index=False)
        average_recall = dict()
        average_recall_naive = dict()
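        # When resuming from an existing CSV, the recall dicts are read back as strings, hence ast.literal_eval below.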
        for k in k_list:
            average_recall[k] = sum(
                [
                    ast.literal_eval(d).get(k) if isinstance(d, str) else d.get(k)
                    for d in existing_df["recall"]
                ]
            ) / len(existing_df)
            average_recall_naive[k] = sum(
                [
                    ast.literal_eval(d).get(k) if isinstance(d, str) else d.get(k)
                    for d in existing_df["recall_naive"]
                ]
            ) / len(existing_df)
        _print_recall_line(average_recall, pre_str="Average recall of DeepSearcher: ")
        _print_recall_line(average_recall_naive, pre_str="Average recall of naive RAG   : ")
        print("")
    print("Finish results to save.")


def main_eval():
    parser = argparse.ArgumentParser(prog="evaluate", description="Deep Searcher evaluation.")
    parser.add_argument(
        "--dataset",
        type=str,
        default="2wikimultihopqa",
        help="Dataset name, default is `2wikimultihopqa`. More datasets will be supported in the future.",
    )
    parser.add_argument(
        "--config_yaml",
        type=str,
        default="./eval_config.yaml",
        help="Configuration yaml file path, default is `./eval_config.yaml`",
    )
    parser.add_argument(
        "--pre_num",
        type=int,
        default=30,
        help="Number of samples to evaluate, default is 30",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./eval_output",
        help="Output root directory, default is `./eval_output`",
    )
    parser.add_argument(
        "--skip_load",
        action="store_true",
        help="Whether to skip loading the dataset. Default it don't skip loading. If you want to skip loading, please set this flag.",
    )
    parser.add_argument(
        "--flag",
        type=str,
        default="result",
        help="Flag for evaluation, default is `result`",
    )

    args = parser.parse_args()

    config = Configuration(config_path=args.config_yaml)
    init_config(config=config)

    evaluate(
        dataset=args.dataset,
        output_root=args.output_dir,
        pre_num=args.pre_num,
        skip_load=args.skip_load,
        flag=args.flag,
    )


if __name__ == "__main__":
    main_eval()
